import torch
from torch import nn


class PMSELoss(nn.Module):
    """Mean-squared-error loss restricted to a column slice of the target.

    The comparison mode is selected by ``style``:

    * ``'pose'`` -- MSE between ``input[:, :pose_dim]`` and
      ``target[:, :pose_dim]`` (both tensors are sliced).
    * ``'face'`` -- MSE between the *whole* ``input`` and
      ``target[:, face_start:face_end]`` (only the target is sliced).

    The result is ``torch.mean`` over every element of the squared
    difference, i.e. a scalar tensor.

    The defaults (``pose_dim=12``, ``face_start=12``, ``face_end=240``)
    reproduce the original hard-coded column ranges.

    NOTE(review): presumably the target packs pose features in columns
    0-11 and face features in columns 12-239, while a 'face' prediction
    holds only the face columns (228 wide) -- confirm against callers.
    """

    def __init__(self, style='pose', *, pose_dim=12, face_start=12, face_end=240):
        """Configure which columns the loss compares.

        Args:
            style: ``'pose'`` or ``'face'``; selects the comparison mode.
            pose_dim: number of leading columns compared in ``'pose'`` mode.
            face_start: first target column (inclusive) used in ``'face'`` mode.
            face_end: target column bound (exclusive) used in ``'face'`` mode.
        """
        super().__init__()
        self.style = style
        self.pose_dim = pose_dim
        self.face_start = face_start
        self.face_end = face_end

    def forward(self, input, target):
        """Return the scalar mean squared error for the configured ``style``.

        Args:
            input: prediction tensor, indexed as ``input[:, ...]`` in
                ``'pose'`` mode, so it needs at least 2 dimensions there.
            target: ground-truth tensor, always column-sliced on dim 1.

        Returns:
            A 0-dim tensor holding the mean of the squared differences.

        Raises:
            ValueError: if ``self.style`` is neither ``'pose'`` nor ``'face'``.
                Previously this case silently returned ``None``.
        """
        # `input` shadows the builtin but is kept for keyword-call compatibility.
        if self.style == 'pose':
            diff = input[:, :self.pose_dim] - target[:, :self.pose_dim]
        elif self.style == 'face':
            # Only the target is sliced; the prediction is compared as-is
            # (standard broadcasting rules apply to the subtraction).
            diff = input - target[:, self.face_start:self.face_end]
        else:
            raise ValueError(
                f"PMSELoss: unknown style {self.style!r}; expected 'pose' or 'face'"
            )
        return torch.mean(diff ** 2)

